# Modified from RFdiffusion (https://github.com/RosettaCommons/RFdiffusion)
# Original license: BSD License
#
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================


import logging
import os
import pickle
import string

import mindspore as ms
import numpy as np
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf

from .. import util
from ..chemical import seq2chars
from ..contigs import ContigMap
from ..diffusion import Diffuser
from ..kinematics import get_init_xyz, xyz_to_t2d
from ..model_input_logger import pickle_function_call
from ..potentials.manager import PotentialManager
from ..RoseTTAFoldModel import RoseTTAFoldModule
from ..util_module import ComputeAllAtomCoords
from . import ab_pose, symmetry
from . import utils as iu
from .ab_util import (
    correct_selfcond,
    featurize,
    process_init_selfcond,
    process_selfcond,
)

# Absolute directory of this file; used to resolve the default model and
# schedule directories relative to the package layout.
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))

# Torsion-angle lookup tables re-exported from util; consumed by
# util.get_torsions when featurizing template coordinates in _preprocess.
TOR_INDICES = util.torsion_indices
TOR_CAN_FLIP = util.torsion_can_flip
REF_ANGLES = util.reference_angles


class Sampler:

    def __init__(self, conf: DictConfig):
        """
        Initialize sampler.
        Args:
            conf: Configuration.
        """
        # initialize() checks this flag to decide whether the model weights
        # need loading; it is set to True at the end of initialize().
        self.initialized = False
        self.initialize(conf)

    def ab_design(self):
        """Whether this sampler performs antibody design.

        The base sampler never does; antibody-specific subclasses are
        expected to override this to return True.
        """
        return False

    def initialize(self, conf: DictConfig) -> None:
        """
        Initialize sampler.
        Args:
            conf: Configuration

        - Selects appropriate model from input
        - Assembles Config from model checkpoint and command line overrides
        - Reloads model weights only when needed (first call, or when the
          checkpoint override path changed since the previous call)
        - Builds the diffuser, symmetry handler and other helper objects
        """
        self._log = logging.getLogger(__name__)
        # On the first call self._conf does not exist yet; the short-circuit
        # on self.initialized guarantees we never touch it in that case.
        needs_model_reload = (
            not self.initialized
            or conf.inference.ckpt_override_path
            != self._conf.inference.ckpt_override_path
        )

        # Assign config to Sampler
        self._conf = conf

        ################################
        ### Select Appropriate Model ###
        ################################

        if conf.inference.model_directory_path is not None:
            model_directory = conf.inference.model_directory_path
        else:
            model_directory = f"{SCRIPT_DIR}/../../models"

        print(f"Reading models from {model_directory}")

        # Initialize inference only helper objects to Sampler
        if conf.inference.ckpt_override_path is not None:
            self.ckpt_path = conf.inference.ckpt_override_path
            print(
                "WARNING: You're overriding the checkpoint path from the defaults. Check that the model you're providing can run with the inputs you're providing."
            )
        else:
            if (
                conf.contigmap.inpaint_seq is not None
                or conf.contigmap.provide_seq is not None
                or conf.contigmap.inpaint_str
            ):
                # use model trained for inpaint_seq
                if conf.contigmap.provide_seq is not None:
                    # this is only used for partial diffusion
                    assert (
                        conf.diffuser.partial_T is not None
                    ), "The provide_seq input is specifically for partial diffusion"
                if conf.scaffoldguided.scaffoldguided:
                    self.ckpt_path = f"{model_directory}/InpaintSeq_Fold_ckpt.ckpt"
                else:
                    self.ckpt_path = f"{model_directory}/InpaintSeq_ckpt.ckpt"
            elif (
                conf.ppi.hotspot_res is not None
                and conf.scaffoldguided.scaffoldguided is False
            ):
                # use complex trained model
                self.ckpt_path = f"{model_directory}/Complex_base_ckpt.ckpt"
            elif conf.scaffoldguided.scaffoldguided is True:
                # use complex and secondary structure-guided model
                self.ckpt_path = f"{model_directory}/Complex_Fold_base_ckpt.ckpt"
            else:
                # use default model
                self.ckpt_path = f"{model_directory}/Base_ckpt.ckpt"
        # for saving in trb file:
        assert (
            self._conf.inference.trb_save_ckpt_path is None
        ), "trb_save_ckpt_path is not the place to specify an input model. Specify in inference.ckpt_override_path"
        self._conf["inference"]["trb_save_ckpt_path"] = self.ckpt_path

        #######################
        ### Assemble Config ###
        #######################

        if needs_model_reload:
            # Load checkpoint, so that we can assemble the config
            self.load_checkpoint()
            self.assemble_config_from_chk()
            # Now actually load the model weights into RF
            self.model = self.load_model()
        else:
            self.assemble_config_from_chk()

        # Debug aid: dump all model parameter names (sorted) for inspection,
        # e.g. to compare against a reference PyTorch checkpoint.
        param_names = sorted(name for name, _ in self.model.parameters_and_names())
        with open("ms_parameters.txt", "w") as f:
            f.writelines(name + "\n" for name in param_names)

        self.initialized = True

        # Initialize helper objects
        self.inf_conf = self._conf.inference
        self.contig_conf = self._conf.contigmap
        self.denoiser_conf = self._conf.denoiser
        self.ppi_conf = self._conf.ppi
        self.potential_conf = self._conf.potentials
        self.diffuser_conf = self._conf.diffuser
        self.preprocess_conf = self._conf.preprocess

        if "antibody" in conf:
            self.ab_conf = self._conf.antibody

        if conf.inference.schedule_directory_path is not None:
            schedule_directory = conf.inference.schedule_directory_path
        else:
            schedule_directory = f"{SCRIPT_DIR}/../../schedules"

        # Ensure the schedule cache directory exists. makedirs(exist_ok=True)
        # avoids the check-then-create race of `if not exists: mkdir`.
        os.makedirs(schedule_directory, exist_ok=True)
        self.diffuser = Diffuser(**self._conf.diffuser, cache_dir=schedule_directory)

        ###########################
        ### Initialise Symmetry ###
        ###########################

        if self.inf_conf.symmetry is not None:
            self.symmetry = symmetry.SymGen(
                self.inf_conf.symmetry,
                self.inf_conf.recenter,
                self.inf_conf.radius,
                self.inf_conf.model_only_neighbors,
            )
        else:
            self.symmetry = None

        self.allatom = ComputeAllAtomCoords()

        if not self.ab_design():
            if self.inf_conf.input_pdb is None:
                # set default pdb
                script_dir = os.path.dirname(os.path.realpath(__file__))
                self.inf_conf.input_pdb = os.path.join(
                    script_dir, "../../examples/input_pdbs/1qys.pdb"
                )
            self.target_feats = iu.process_target(
                self.inf_conf.input_pdb, parse_hetatom=True, center=False
            )
        self.chain_idx = None
        self.idx_pdb = None

        ##############################
        ### Handle Partial Noising ###
        ##############################

        # Partial diffusion starts from an intermediate timestep partial_T
        # instead of the full T steps.
        if self.diffuser_conf.partial_T:
            assert self.diffuser_conf.partial_T <= self.diffuser_conf.T
            self.t_step_input = int(self.diffuser_conf.partial_T)
        else:
            self.t_step_input = int(self.diffuser_conf.T)

    @property
    def T(self):
        """Maximum number of diffusion timesteps this design protocol performs.

        Returns:
            int: The configured diffuser ``T``.
        """
        return self.diffuser_conf.T

    def load_checkpoint(self) -> None:
        """Loads RF checkpoint, from which config can be generated."""
        self._log.info(f"Reading checkpoint from {self.ckpt_path}")
        print("This is inf_conf.ckpt_path")
        print(self.ckpt_path)
        with open(self.ckpt_path, "rb") as f:
            self.ckpt = pickle.load(f)

    def assemble_config_from_chk(self) -> None:
        """
        Function for loading model config from checkpoint directly.

        Takes:
            - config file

        Actions:
            - Replaces all -model and -diffuser items
            - Keeps the inference default for any key absent from the checkpoint

        This throws an error if there is a flag in the checkpoint 'config_dict' that isn't in the inference config.
        This should ensure that whenever a feature is added in the training setup, it is accounted for in the inference script.

        """
        # get overrides to re-apply after building the config from the checkpoint
        overrides = []
        if HydraConfig.initialized():
            overrides = HydraConfig.get().overrides.task
        print("Assembling -model, -diffuser and -preprocess configs from checkpoint")

        for cat in ["model", "diffuser", "preprocess", "antibody"]:
            if cat not in self._conf:
                continue
            for key in self._conf[cat]:
                try:
                    ckpt_value = self.ckpt["config_dict"][cat][key]
                except KeyError:
                    # Key (or whole category) absent from the checkpoint
                    # config: keep the inference default.
                    continue
                print(f"USING MODEL CONFIG: self._conf[{cat}][{key}] = {ckpt_value}")
                self._conf[cat][key] = ckpt_value

        # add overrides back in again
        for override in overrides:
            cat = override.split(".")[0]
            if cat in ["model", "diffuser", "preprocess"]:
                name = override.split("=")[0]
                print(
                    f'WARNING: You are changing {name} from the value this model was trained with. Are you sure you know what you are doing?'
                )
                key = override.split(".")[1].split("=")[0]
                # Cast the override string to the type of the current value so
                # the config keeps a consistent type.
                mytype = type(self._conf[cat][key])
                self._conf[cat][key] = mytype(override.split("=")[1])

    def load_model(self):
        """Create RosettaFold model from preloaded checkpoint.

        Builds a RoseTTAFoldModule sized from the assembled preprocess config,
        optionally instruments its forward for input logging, switches it to
        inference mode and loads the checkpoint weights.

        Returns:
            The RoseTTAFoldModule with weights loaded, in inference mode.
        """

        # Read input dimensions from checkpoint.
        self.d_t1d = self._conf.preprocess.d_t1d
        self.d_t2d = self._conf.preprocess.d_t2d
        model = RoseTTAFoldModule(
            **self._conf.model,
            d_t1d=self.d_t1d,
            d_t2d=self.d_t2d,
            T=self._conf.diffuser.T,
        )
        if self._conf.logging.inputs:
            # Pickle every forward() call's inputs to disk for debugging.
            pickle_dir = pickle_function_call(model, "forward", "inference")
            print(f"pickle_dir: {pickle_dir}")
        # Inference mode (MindSpore equivalent of torch's model.eval()).
        model.set_train(False)
        self._log.info("Loading checkpoint.")
        ms.load_param_into_net(model, self.ckpt["model_state_dict"], strict_load=True)
        return model

    def construct_contig(self, target_feats):
        """
        Construct contig class describing the protein to be generated

        Args:
            target_feats: Parsed target-PDB features (as produced by
                iu.process_target).

        Returns:
            ContigMap built from target_feats and the contigmap config.
        """
        self._log.info(f"Using contig: {self.contig_conf.contigs}")
        return ContigMap(target_feats, **self.contig_conf)

    def construct_denoiser(self, L, visible):
        """Make length-specific denoiser.

        Args:
            L: Number of residues in the design.
            visible: Sequence-visibility mask. Currently not forwarded to
                Denoise — NOTE(review): confirm Denoise does not need it.

        Returns:
            An iu.Denoise configured from the diffuser and denoiser configs.
        """
        kwargs = {
            **OmegaConf.to_container(self.diffuser_conf),
            **OmegaConf.to_container(self.denoiser_conf),
            "L": L,
            "diffuser": self.diffuser,
            "potential_manager": self.potential_manager,
        }
        return iu.Denoise(**kwargs)

    def sample_init(self, return_forward_trajectory=False):
        """
        Initial features to start the sampling process.

        Modify signature and function body for different initialization
        based on the config.

        Args:
            return_forward_trajectory: Currently unused; kept for interface
                compatibility with other samplers.

        Returns:
            xt: Starting positions with a portion of them randomly sampled.
            seq_t: Starting sequence with a portion of them set to unknown.
        """

        #######################
        ### Parse input pdb ###
        #######################

        self.target_feats = iu.process_target(
            self.inf_conf.input_pdb, parse_hetatom=True, center=False
        )

        ################################
        ### Generate specific contig ###
        ################################

        # Generate a specific contig from the range of possibilities specified at input

        self.contig_map = self.construct_contig(self.target_feats)
        self.mappings = self.contig_map.get_mappings()
        # [1,L] boolean masks: True where sequence / structure is fixed (motif).
        self.mask_seq = ms.from_numpy(self.contig_map.inpaint_seq)[None, :]
        self.mask_str = ms.from_numpy(self.contig_map.inpaint_str)[None, :]
        self.binderlen = len(self.contig_map.inpaint)

        #######################################
        ### Resolve cyclic peptide indices ###
        #######################################
        if self._conf.inference.cyclic:
            if self._conf.inference.cyc_chains is None:
                # default to all residues being cyclized
                self.cyclic_reses = ~self.mask_str.squeeze()
            else:
                # use cyc_chains arg to determine cyclic_reses mask
                assert (
                    type(self._conf.inference.cyc_chains) is str
                ), "cyc_chains arg must be string"
                # e.g. "ab" -> ["A", "B"]: one chain letter per character
                cyc_chains = self._conf.inference.cyc_chains
                cyc_chains = [i.upper() for i in cyc_chains]
                hal_idx = (
                    self.contig_map.hal
                )  # the pdb indices of output, knowledge of different chains
                is_cyclized = (
                    ms.mint.zeros_like(self.mask_str).bool().squeeze()
                )  # initially empty

                for ch in cyc_chains:
                    ch_mask = ms.tensor([idx[0] == ch for idx in hal_idx]).bool()
                    # set this whole chain to be cyclic
                    is_cyclized[ch_mask] = True
                self.cyclic_reses = is_cyclized
        else:
            self.cyclic_reses = ms.mint.zeros_like(self.mask_str).bool().squeeze()

        ####################
        ### Get Hotspots ###
        ####################

        self.hotspot_0idx = iu.get_idx0_hotspots(
            self.mappings, self.ppi_conf, self.binderlen
        )

        #####################################
        ### Initialise Potentials Manager ###
        #####################################

        self.potential_manager = PotentialManager(
            self.potential_conf,
            self.ppi_conf,
            self.diffuser_conf,
            self.inf_conf,
            self.hotspot_0idx,
            self.binderlen,
        )

        ###################################
        ### Initialize other attributes ###
        ###################################

        xyz_27 = self.target_feats["xyz_27"]
        mask_27 = self.target_feats["mask_27"]
        seq_orig = self.target_feats["seq"]
        L_mapped = len(self.contig_map.ref)
        contig_map = self.contig_map

        self.diffusion_mask = self.mask_str
        length_bound = self.contig_map.sampled_mask_length_bound.copy()

        first_res = 0
        self.chain_idx = []
        self.idx_pdb = []
        # Chain letters already used by the reference; new chains pick from the rest.
        all_chains = {contig_ref[0] for contig_ref in self.contig_map.ref}
        available_chains = sorted(list(set(string.ascii_letters) - all_chains))

        # Iterate over each chain
        for last_res in length_bound:
            chain_ids = {
                contig_ref[0] for contig_ref in self.contig_map.ref[first_res:last_res]
            }
            # If we are designing this chain, it will have a '_' in the contig map
            # Renumber this chain from 1
            if "_" in chain_ids:
                self.idx_pdb += [idx + 1 for idx in range(last_res - first_res)]
                chain_ids = chain_ids - {"_"}
                # If there are no fixed residues that have a chain id, pick the first available letter
                if not chain_ids:
                    if not available_chains:
                        raise ValueError(
                            f"No available chains!  You are trying to design a new chain, and you have "
                            f"already used all upper- and lower-case chain ids (up to 52 chains): "
                            f"{','.join(all_chains)}."
                        )
                    chain_id = available_chains[0]
                    available_chains.remove(chain_id)
                # Otherwise, use the chain of the fixed (motif) residues
                else:
                    assert (
                        len(chain_ids) == 1
                    ), f"Error: Multiple chain IDs in chain: {chain_ids}"
                    chain_id = list(chain_ids)[0]
                self.chain_idx += [chain_id] * (last_res - first_res)
            # If this is a fixed chain, maintain the chain and residue numbering
            else:
                self.idx_pdb += [
                    contig_ref[1]
                    for contig_ref in self.contig_map.ref[first_res:last_res]
                ]
                assert (
                    len(chain_ids) == 1
                ), f"Error: Multiple chain IDs in chain: {chain_ids}"
                self.chain_idx += [list(chain_ids)[0]] * (last_res - first_res)
            first_res = last_res

        ####################################
        ### Generate initial coordinates ###
        ####################################

        if self.diffuser_conf.partial_T:
            assert (
                xyz_27.shape[0] == L_mapped
            ), f"there must be a coordinate in the input PDB for \
                    each residue implied by the contig string for partial diffusion.  length of \
                    input PDB != length of contig string: {xyz_27.shape[0]} != {L_mapped}"
            assert (
                contig_map.hal_idx0 == contig_map.ref_idx0
            ), f"for partial diffusion there can \
                    be no offset between the index of a residue in the input and the index of the \
                    residue in the output, {contig_map.hal_idx0} != {contig_map.ref_idx0}"
            # Partially diffusing from a known structure
            xyz_mapped = xyz_27
            atom_mask_mapped = mask_27
        else:
            # Fully diffusing from points initialised at the origin
            # adjust size of input xt according to residue map
            xyz_mapped = ms.mint.full((1, 1, L_mapped, 27, 3), np.nan)
            xyz_mapped[:, :, contig_map.hal_idx0, ...] = xyz_27[
                contig_map.ref_idx0, ...
            ]
            xyz_motif_prealign = xyz_mapped.clone()
            # Center of mass of the motif CA atoms (index 1 = CA)
            motif_prealign_com = xyz_motif_prealign[0, 0, :, 1].mean(dim=0)
            self.motif_com = xyz_27[contig_map.ref_idx0, 1].mean(dim=0)
            xyz_mapped = get_init_xyz(xyz_mapped).squeeze()
            # adjust the size of the input atom map
            atom_mask_mapped = ms.mint.full((L_mapped, 27), False)
            atom_mask_mapped[contig_map.hal_idx0] = mask_27[contig_map.ref_idx0]

        # Diffuse the contig-mapped coordinates
        if self.diffuser_conf.partial_T:
            assert (
                self.diffuser_conf.partial_T <= self.diffuser_conf.T
            ), "Partial_T must be less than T"
            self.t_step_input = int(self.diffuser_conf.partial_T)
        else:
            self.t_step_input = int(self.diffuser_conf.T)
        t_list = np.arange(1, self.t_step_input + 1)

        #################################
        ### Generate initial sequence ###
        #################################

        # 21 is the mask token
        seq_t = ms.mint.full((1, L_mapped), 21).squeeze()
        seq_t[contig_map.hal_idx0] = seq_orig[contig_map.ref_idx0]

        # Unmask sequence if desired
        if self._conf.contigmap.provide_seq is not None:
            seq_t[self.mask_seq.squeeze()] = seq_orig[self.mask_seq.squeeze()]

        # Everything outside the visible-sequence mask is (re)set to the mask token
        seq_t[~self.mask_seq.squeeze()] = 21
        seq_t = ms.mint.nn.functional.one_hot(seq_t, num_classes=22).float()  # [L,22]
        seq_orig = ms.mint.nn.functional.one_hot(
            seq_orig, num_classes=22
        ).float()  # [L,22]

        # Forward-diffuse the pose; fa_stack holds one frame per timestep in t_list
        fa_stack, xyz_true = self.diffuser.diffuse_pose(
            xyz_mapped,
            ms.mint.clone(seq_t),
            atom_mask_mapped.squeeze(),
            diffusion_mask=self.diffusion_mask.squeeze(),
            t_list=t_list,
        )
        xT = fa_stack[-1].squeeze()[:, :14, :]
        xt = ms.mint.clone(xT)

        self.denoiser = self.construct_denoiser(
            len(self.contig_map.ref), visible=self.mask_seq.squeeze()
        )

        ######################
        ### Apply Symmetry ###
        ######################

        if self.symmetry is not None:
            xt, seq_t = self.symmetry.apply_symmetry(xt, seq_t)
        self._log.info(f"Sequence init: {seq2chars(ms.mint.argmax(seq_t, dim=-1))}")

        # No recycling state at the start of sampling
        self.msa_prev = None
        self.pair_prev = None
        self.state_prev = None

        #########################################
        ### Parse ligand for ligand potential ###
        #########################################

        if self.potential_conf.guiding_potentials is not None:
            if any(
                list(
                    filter(
                        lambda x: "substrate_contacts" in x,
                        self.potential_conf.guiding_potentials,
                    )
                )
            ):
                assert (
                    len(self.target_feats["xyz_het"]) > 0
                ), "If you're using the Substrate Contact potential, \
                        you need to make sure there's a ligand in the input_pdb file!"
                het_names = np.array(
                    [i["name"].strip() for i in self.target_feats["info_het"]]
                )
                xyz_het = self.target_feats["xyz_het"][
                    het_names == self._conf.potentials.substrate
                ]
                xyz_het = ms.from_numpy(xyz_het)
                assert (
                    xyz_het.shape[0] > 0
                ), f"expected >0 heteroatoms from ligand with name {self._conf.potentials.substrate}"
                # NOTE(review): xyz_motif_prealign is only defined on the
                # full-diffusion branch above; this block would raise for
                # partial_T runs with substrate_contacts — confirm intended.
                xyz_motif_prealign = xyz_motif_prealign[0, 0][
                    self.diffusion_mask.squeeze()
                ]
                motif_prealign_com = xyz_motif_prealign[:, 1].mean(dim=0)
                xyz_het_com = xyz_het.mean(dim=0)
                for pot in self.potential_manager.potentials_to_apply:
                    pot.motif_substrate_atoms = xyz_het
                    pot.diffusion_mask = self.diffusion_mask.squeeze()
                    pot.xyz_motif = xyz_motif_prealign
                    pot.diffuser = self.diffuser
        return xt, seq_t

    def _preprocess(self, seq, xyz_t, t, repack=False):
        """
        Function to prepare inputs to diffusion model

            seq (L,22) one-hot sequence

            msa_masked (1,1,L,48)

            msa_full (1,1,L,25)

            xyz_t (L,14,3) template crds (diffused)

            t1d (1,L,28) this is the t1d before tacking on the chi angles:
                - seq + unknown/mask (21)
                - global timestep (1-t/T if not motif else 1) (1)

                MODEL SPECIFIC:
                - contacting residues: for ppi. Target residues in contact with binder (1)
                - empty feature (legacy) (1)
                - ss (H, E, L, MASK) (4)

            t2d (1, L, L, 45)
                - last plane is block adjacency

            Note: `repack` is currently unused.
        """

        L = seq.shape[0]

        ##################
        ### msa_masked ###
        ##################
        # Sequence is duplicated into both MSA feature slots; planes 46/47
        # flag the first/last residue.
        msa_masked = ms.mint.zeros((1, 1, L, 48))
        msa_masked[:, :, :, :22] = seq[None, None]
        msa_masked[:, :, :, 22:44] = seq[None, None]
        msa_masked[:, :, 0, 46] = 1.0
        msa_masked[:, :, -1, 47] = 1.0

        ################
        ### msa_full ###
        ################
        msa_full = ms.mint.zeros((1, 1, L, 25))
        msa_full[:, :, :, :22] = seq[None, None]
        msa_full[:, :, 0, 23] = 1.0
        msa_full[:, :, -1, 24] = 1.0

        ###########
        ### t1d ###
        ###########

        # Here we need to go from one hot with 22 classes to one hot with 21 classes (last plane is missing token)
        t1d = ms.mint.zeros((1, 1, L, 21))

        # Remap class 21 (mask) onto class 20 so the 21-class encoding holds it
        seqt1d = ms.mint.clone(seq)
        for idx in range(L):
            if seqt1d[idx, 21] == 1:
                seqt1d[idx, 20] = 1
                seqt1d[idx, 21] = 0

        t1d[:, :, :, :21] = seqt1d[None, None, :, :21]

        # Set timestep feature to 1 where diffusion mask is True, else 1-t/T
        timefeature = ms.mint.zeros((L)).float()
        timefeature[self.mask_str.squeeze()] = 1
        timefeature[~self.mask_str.squeeze()] = 1 - t / self.T
        timefeature = timefeature[None, None, ..., None]

        t1d = ms.mint.cat((t1d, timefeature), dim=-1).float()

        #############
        ### xyz_t ###
        #############
        # Blank out side chains (atoms 3+) for masked-sequence / diffused residues
        if self.preprocess_conf.sidechain_input:
            xyz_t[ms.mint.where(seq == 21, True, False), 3:, :] = float("nan")
        else:
            xyz_t[~self.mask_str.squeeze(), 3:, :] = float("nan")

        xyz_t = xyz_t[None, None]

        # Pad from 14 to 27 atoms with NaNs
        xyz_t = ms.mint.cat(
            (xyz_t, ms.mint.full((1, 1, L, 13, 3), float("nan"))), dim=3
        )

        ###########
        ### t2d ###
        ###########
        t2d = xyz_to_t2d(xyz_t)

        ###########
        ### idx ###
        ###########
        idx = ms.tensor(self.contig_map.rf)[None]

        ###############
        ### alpha_t ###
        ###############
        seq_tmp = t1d[..., :-1].argmax(dim=-1).reshape(-1, L)
        alpha, _, alpha_mask, _ = util.get_torsions(
            xyz_t.reshape(-1, L, 27, 3), seq_tmp, TOR_INDICES, TOR_CAN_FLIP, REF_ANGLES
        )
        # Invalidate torsions whose angles came out NaN, then zero the NaNs
        alpha_mask = ms.mint.logical_and(alpha_mask, ~ms.mint.isnan(alpha[..., 0]))
        alpha[ms.mint.isnan(alpha)] = 0.0
        alpha = alpha.reshape(1, -1, L, 10, 2)
        alpha_mask = alpha_mask.reshape(1, -1, L, 10, 1)
        alpha_t = ms.mint.cat((alpha, alpha_mask.to(dtype=ms.float32)), dim=-1).reshape(
            1, -1, L, 30
        )

        ######################
        ### added_features ###
        ######################
        if self.preprocess_conf.d_t1d >= 24:  # add hotspot residues
            hotspot_tens = ms.mint.zeros(L).float()
            if self.ppi_conf.hotspot_res is None:
                print(
                    "WARNING: you're using a model trained on complexes and hotspot residues, without specifying hotspots.\
                         If you're doing monomer diffusion this is fine"
                )
                hotspot_idx = []
            else:
                # hotspot_res entries look like "A56" -> ("A", 56)
                hotspots = [(i[0], int(i[1:])) for i in self.ppi_conf.hotspot_res]
                hotspot_idx = []
                for i, res in enumerate(self.contig_map.con_ref_pdb_idx):
                    if res in hotspots:
                        hotspot_idx.append(self.contig_map.hal_idx0[i])
                hotspot_tens[hotspot_idx] = 1.0

            # Add blank (legacy) feature and hotspot tensor
            t1d = ms.mint.cat(
                (
                    t1d,
                    ms.mint.zeros_like(t1d[..., :1]),
                    hotspot_tens[None, None, ..., None],
                ),
                dim=-1,
            )

        return (
            msa_masked,
            msa_full,
            seq[None],
            ms.mint.squeeze(xyz_t, dim=0),
            idx,
            t1d,
            t2d,
            xyz_t,
            alpha_t,
        )

    def sample_step(self, *, t, x_t, seq_init, final_step):
        """Generate the next pose that the model should be supplied at timestep t-1.

        Args:
            t (int): The timestep that has just been predicted
            x_t (mindspore.tensor): (L,14,3) The residue positions at the beginning of this timestep
            seq_init (mindspore.tensor): (L,22) The initialized sequence used in updating the sequence.
            final_step (int): Timestep at which denoising stops; at or below it
                the model's x0 prediction is returned directly.

        Returns:
            px0: (L,14,3) The model's prediction of x0.
            x_t_1: (L,14,3) The updated positions of the next step.
            seq_t_1: The updated sequence of the next step.
            plddt: (L, 1) Predicted lDDT of x0.
        """
        msa_masked, msa_full, seq_in, xt_in, idx_pdb, t1d, t2d, xyz_t, alpha_t = (
            self._preprocess(seq_init, x_t, t)
        )

        if self.symmetry is not None:
            idx_pdb, self.chain_idx = self.symmetry.res_idx_procesing(res_idx=idx_pdb)

        ####################
        ### Forward Pass ###
        ####################

        # The base sampler carries no recycling state between timesteps,
        # so msa/pair/state inputs are always None.
        msa_prev, pair_prev, px0, state_prev, alpha, logits, plddt = self.model(
            msa_masked,
            msa_full,
            seq_in,
            xt_in,
            idx_pdb,
            t1d=t1d,
            t2d=t2d,
            xyz_t=xyz_t,
            alpha_t=alpha_t,
            msa_prev=None,
            pair_prev=None,
            state_prev=None,
            t=ms.tensor(t),
            return_infer=True,
            motif_mask=self.diffusion_mask.squeeze(),
        )

        # prediction of X0: rebuild all-atom coordinates from frames + torsions
        _, px0 = self.allatom(ms.mint.argmax(seq_in, dim=-1), px0, alpha)
        px0 = px0.squeeze()[:, :14]

        #####################
        ### Get next pose ###
        #####################

        if t > final_step:
            # NOTE(review): seq_init is documented as already one-hot (L,22);
            # one_hot here would add another axis — confirm downstream
            # consumers expect this (the else-branch clones seq_init instead).
            seq_t_1 = ms.mint.one_hot(seq_init, num_classes=22)
            x_t_1, px0 = self.denoiser.get_next_pose(
                xt=x_t,
                px0=px0,
                t=t,
                diffusion_mask=self.mask_str.squeeze(),
                align_motif=self.inf_conf.align_motif,
            )
        else:
            # Final step: hand back the model's x0 prediction unchanged.
            x_t_1 = ms.mint.clone(px0)
            seq_t_1 = ms.mint.clone(seq_init)

        if self.symmetry is not None:
            x_t_1, seq_t_1 = self.symmetry.apply_symmetry(x_t_1, seq_t_1)

        return px0, x_t_1, seq_t_1, plddt


class SelfConditioning(Sampler):
    """
    Model Runner for self conditioning
    pX0[t+1] is provided as a template input to the model at time t
    """

    def sample_step(self, *, t, x_t, seq_init, final_step):
        """
        Generate the next pose that the model should be supplied at timestep t-1.
        Args:
            t (int): The timestep that has just been predicted
            x_t (mindspore.tensor): (L,14,3) The residue positions at the beginning of this timestep
            seq_init (mindspore.tensor): (L,22) The initialized sequence used in updating the sequence.
            final_step (int): The last timestep of the trajectory; at or below it,
                px0 is returned directly as x_t_1 instead of a denoised pose.
        Returns:
            px0: (L,14,3) The model's prediction of x0.
            x_t_1: (L,14,3) The updated positions of the next step.
            seq_t_1: (L,22) The sequence passed to the next step (a clone of seq_init).
            plddt: (L, 1) Predicted lDDT of x0.
        """

        msa_masked, msa_full, seq_in, xt_in, idx_pdb, t1d, t2d, xyz_t, alpha_t = (
            self._preprocess(seq_init, x_t, t)
        )
        B, N, L = xyz_t.shape[:3]

        ##################################
        ######## Str Self Cond ###########
        ##################################
        # Self-condition on the previous step's px0, except on the very first
        # step of the trajectory (t == T, or the partial-diffusion entry step)
        # where no previous prediction exists yet.
        if (t < self.diffuser.T) and (t != self.diffuser_conf.partial_T):
            # Pad the previous backbone prediction with 24 zeroed atom slots so
            # the template reaches the 27-atom layout expected downstream.
            zeros = ms.mint.zeros((B, 1, L, 24, 3)).float()
            xyz_t = ms.mint.cat(
                (self.prev_pred.unsqueeze(1), zeros), dim=-2
            )  # [B,T,L,27,3]
            t2d_44 = xyz_to_t2d(xyz_t)  # [B,T,L,L,44]
        else:
            # First step: feed blank templates (nothing to self-condition on).
            xyz_t = ms.mint.zeros_like(xyz_t)
            t2d_44 = ms.mint.zeros_like(t2d[..., :44])
        # Overwrite only the first 44 channels so any extra t2d features
        # appended by subclasses (e.g. block adjacency) are preserved.
        # No effect if t2d is only dim 44
        t2d[..., :44] = t2d_44

        if self.symmetry is not None:
            idx_pdb, self.chain_idx = self.symmetry.res_idx_procesing(res_idx=idx_pdb)

        ####################
        ### Forward Pass ###
        ####################

        msa_prev, pair_prev, px0, state_prev, alpha, logits, plddt = self.model(
            msa_masked,
            msa_full,
            seq_in,
            xt_in,
            idx_pdb,
            t1d=t1d,
            t2d=t2d,
            xyz_t=xyz_t,
            alpha_t=alpha_t,
            msa_prev=None,
            pair_prev=None,
            state_prev=None,
            t=ms.tensor(t),
            return_infer=True,
            motif_mask=self.diffusion_mask.squeeze(),
            cyclic_reses=self.cyclic_reses,
        )

        # Optionally symmetrise px0 before caching it for the next step's
        # self-conditioning template (backbone atoms only).
        if self.symmetry is not None and self.inf_conf.symmetric_self_cond:
            px0 = self.symmetrise_prev_pred(px0=px0, seq_in=seq_in, alpha=alpha)[
                :, :, :3
            ]

        # Cache for self-conditioning at t-1.
        self.prev_pred = ms.mint.clone(px0)

        # prediction of X0
        _, px0 = self.allatom(ms.mint.argmax(seq_in, dim=-1), px0, alpha)
        px0 = px0.squeeze()[:, :14]

        ###########################
        ### Generate Next Input ###
        ###########################

        seq_t_1 = ms.mint.clone(seq_init)
        if t > final_step:
            # Partially denoise towards px0 to obtain x_{t-1}.
            x_t_1, px0 = self.denoiser.get_next_pose(
                xt=x_t,
                px0=px0,
                t=t,
                diffusion_mask=self.mask_str.squeeze(),
                align_motif=self.inf_conf.align_motif,
                include_motif_sidechains=self.preprocess_conf.motif_sidechain_input,
            )
            self._log.info(
                f"Timestep {t}, input to next step: { seq2chars(ms.mint.argmax(seq_t_1, dim=-1).tolist())}"
            )
        else:
            # Final step: hand back the model's prediction unmodified.
            x_t_1 = ms.mint.clone(px0)

        ######################
        ### Apply symmetry ###
        ######################

        if self.symmetry is not None:
            x_t_1, seq_t_1 = self.symmetry.apply_symmetry(x_t_1, seq_t_1)

        return px0, x_t_1, seq_t_1, plddt

    def symmetrise_prev_pred(self, px0, seq_in, alpha):
        """
        Method for symmetrising px0 output for self-conditioning.

        Builds all-atom coordinates from the backbone prediction, applies the
        configured symmetry operation to the first 14 atoms, and returns the
        symmetrised coordinates with a leading batch dimension restored.
        """
        _, px0_aa = self.allatom(ms.mint.argmax(seq_in, dim=-1), px0, alpha)
        px0_sym, _ = self.symmetry.apply_symmetry(
            px0_aa.squeeze()[:, :14], ms.mint.argmax(seq_in, dim=-1).squeeze()
        )
        px0_sym = px0_sym[None]
        return px0_sym


class ScaffoldedSampler(SelfConditioning):
    """
    Model Runner for Scaffold-Constrained diffusion.

    Conditions generation on secondary structure (ss) and block adjacency
    (adj) features, appended to t1d/t2d respectively, optionally in the
    presence of a fixed target chain.
    """

    def __init__(self, conf: DictConfig):
        """
        Initialize scaffolded sampler.
        Two basic approaches here:
            i) Given a block adjacency/secondary structure input, generate a fold (in the presence or absence of a target)
                - This allows easy generation of binders or specific folds
                - Allows simple expansion of an input, to sample different lengths
            ii) Providing a contig input and corresponding block adjacency/secondary structure input
                - This allows mixed motif scaffolding and fold-conditioning.
                - Adjacency/secondary structure inputs must correspond exactly in length to the contig string
        """
        super().__init__(conf)
        # initialize BlockAdjacency sampling class
        if conf.scaffoldguided.scaffold_dir is None:
            # No scaffold directory: per-residue secondary structure must then
            # come from the contigmap inpaint_str_* options.
            assert any(
                x is not None
                for x in (
                    conf.contigmap.inpaint_str_helix,
                    conf.contigmap.inpaint_str_strand,
                    conf.contigmap.inpaint_str_loop,
                )
            )
            if conf.contigmap.inpaint_str_loop is not None:
                # Truthiness test (was `== False`, the PEP 8 / E712 anti-idiom).
                assert (
                    not conf.scaffoldguided.mask_loops
                ), "You shouldn't be masking loops if you're specifying loop secondary structure"
        else:
            # initialize BlockAdjacency sampling class
            assert all(
                x is None
                for x in (
                    conf.contigmap.inpaint_str_helix,
                    conf.contigmap.inpaint_str_strand,
                    conf.contigmap.inpaint_str_loop,
                )
            ), "can't provide scaffold_dir if you're also specifying per-residue ss"
            self.blockadjacency = iu.BlockAdjacency(conf, conf.inference.num_designs)

        #################################################
        ### Initialize target, if doing binder design ###
        #################################################

        if conf.scaffoldguided.target_pdb:
            self.target = iu.Target(conf.scaffoldguided, conf.ppi.hotspot_res)
            self.target_pdb = self.target.get_target()
            if conf.scaffoldguided.target_ss is not None:
                # Pickled per-residue ss labels -> one-hot over 4 classes
                # (presumably helix/strand/loop/mask — TODO confirm ordering).
                with open(conf.scaffoldguided.target_ss, "rb") as f:
                    self.target_ss = pickle.load(f).long()
                self.target_ss = ms.mint.nn.functional.one_hot(
                    self.target_ss, num_classes=4
                )
                if self._conf.scaffoldguided.contig_crop is not None:
                    self.target_ss = self.target_ss[self.target_pdb["crop_mask"]]
            if conf.scaffoldguided.target_adj is not None:
                # Pickled pairwise adjacency labels -> one-hot over 3 classes.
                with open(conf.scaffoldguided.target_adj, "rb") as f:
                    self.target_adj = pickle.load(f).long()
                self.target_adj = ms.mint.nn.functional.one_hot(
                    self.target_adj, num_classes=3
                )
                if self._conf.scaffoldguided.contig_crop is not None:
                    # Crop both rows and columns of the pairwise map.
                    self.target_adj = self.target_adj[self.target_pdb["crop_mask"]]
                    self.target_adj = self.target_adj[:, self.target_pdb["crop_mask"]]
        else:
            self.target = None
            self.target_pdb = False

    def sample_init(self):
        """
        Wrapper method for taking secondary structure + adj, and outputting xt, seq_t
        """

        ##########################
        ### Process Fold Input ###
        ##########################
        if hasattr(self, "blockadjacency"):
            # Sample a scaffold (length, ss, adjacency) from the scaffold set.
            self.L, self.ss, self.adj = self.blockadjacency.get_scaffold()
            self.adj = ms.mint.nn.functional.one_hot(self.adj.long(), num_classes=3)
        else:
            self.L = 100

        ##############################
        ### Auto-contig generation ###
        ##############################

        if self.contig_conf.contigs is None:
            # process target
            # Fully-diffused binder: NaN coords, mask-token (21) sequence,
            # nothing held fixed.
            xT = ms.mint.full((self.L, 27, 3), np.nan)
            xT = get_init_xyz(xT[None, None]).squeeze()
            seq_T = ms.mint.full((self.L,), 21)
            self.diffusion_mask = ms.mint.full((self.L,), False)
            atom_mask = ms.mint.full((self.L, 27), False)
            self.binderlen = self.L

            if self.target:
                # Append the fixed target chain after the binder.
                target_L = np.shape(self.target_pdb["xyz"])[0]
                # xyz
                target_xyz = ms.mint.full((target_L, 27, 3), np.nan)
                target_xyz[:, :14, :] = ms.from_numpy(self.target_pdb["xyz"])
                xT = ms.mint.cat((xT, target_xyz), dim=0)
                # seq
                seq_T = ms.mint.cat(
                    (seq_T, ms.from_numpy(self.target_pdb["seq"])), dim=0
                )
                # diffusion mask
                self.diffusion_mask = ms.mint.cat(
                    (self.diffusion_mask, ms.mint.full((target_L,), True)), dim=0
                )
                # atom mask
                mask_27 = ms.mint.full((target_L, 27), False)
                mask_27[:, :14] = ms.from_numpy(self.target_pdb["mask"])
                atom_mask = ms.mint.cat((atom_mask, mask_27), dim=0)
                self.L += target_L
                # generate contigmap object
                # Build one contig segment per contiguous run of target
                # residues (break on chain change or residue-number gap).
                contig = []
                for idx, i in enumerate(self.target_pdb["pdb_idx"][:-1]):
                    if idx == 0:
                        start = i[1]
                    if (
                        i[1] + 1 != self.target_pdb["pdb_idx"][idx + 1][1]
                        or i[0] != self.target_pdb["pdb_idx"][idx + 1][0]
                    ):
                        contig.append(f"{i[0]}{start}-{i[1]}/0 ")
                        start = self.target_pdb["pdb_idx"][idx + 1][1]
                contig.append(
                    f"{self.target_pdb['pdb_idx'][-1][0]}{start}-{self.target_pdb['pdb_idx'][-1][1]}/0 "
                )
                contig.append(f"{self.binderlen}-{self.binderlen}")
                contig = ["".join(contig)]
            else:
                contig = [f"{self.binderlen}-{self.binderlen}"]
            self.contig_map = ContigMap(self.target_pdb, contig)
            self.mappings = self.contig_map.get_mappings()
            self.mask_seq = self.diffusion_mask
            self.mask_str = self.diffusion_mask
            L_mapped = len(self.contig_map.ref)

        ############################
        ### Specific Contig mode ###
        ############################

        else:
            # get contigmap from command line
            assert (
                self.target is None
            ), "Giving a target is the wrong way of handling this is you're doing contigs and secondary structure"

            # process target and reinitialise potential_manager. This is here because the 'target' is always set up to be the second chain in out inputs.
            self.target_feats = iu.process_target(self.inf_conf.input_pdb)
            self.contig_map = self.construct_contig(self.target_feats)
            self.mappings = self.contig_map.get_mappings()
            self.mask_seq = ms.from_numpy(self.contig_map.inpaint_seq)[None, :]
            self.mask_str = ms.from_numpy(self.contig_map.inpaint_str)[None, :]
            self.binderlen = len(self.contig_map.inpaint)
            self.L = len(self.contig_map.inpaint_seq)
            target_feats = self.target_feats
            contig_map = self.contig_map

            xyz_27 = target_feats["xyz_27"]
            mask_27 = target_feats["mask_27"]
            seq_orig = target_feats["seq"]
            L_mapped = len(self.contig_map.ref)
            # Motif residues keep their sequence; everything else is masked (21).
            seq_T = ms.mint.full((L_mapped,), 21)
            seq_T[contig_map.hal_idx0] = seq_orig[contig_map.ref_idx0]
            seq_T[~self.mask_seq.squeeze()] = 21

            diffusion_mask = self.mask_str
            self.diffusion_mask = diffusion_mask

            # Place motif coordinates; remaining positions initialised from NaN.
            xT = ms.mint.full((1, 1, L_mapped, 27, 3), np.nan)
            xT[:, :, contig_map.hal_idx0, ...] = xyz_27[contig_map.ref_idx0, ...]
            xT = get_init_xyz(xT).squeeze()
            atom_mask = ms.mint.full((L_mapped, 27), False)
            atom_mask[contig_map.hal_idx0] = mask_27[contig_map.ref_idx0]

            if hasattr(self.contig_map, "ss_spec"):
                # Per-residue ss given via the contig string: adjacency is
                # fully masked (class 2) and ss is built from the spec.
                self.adj = ms.mint.full((L_mapped, L_mapped), 2)  # masked
                self.adj = ms.mint.nn.functional.one_hot(self.adj.long(), num_classes=3)
                self.ss = iu.ss_from_contig(self.contig_map.ss_spec)
            assert L_mapped == self.adj.shape[0]

        #######################################
        ### Resolve cyclic peptide indices ###
        #######################################
        if self._conf.inference.cyclic:
            if self._conf.inference.cyc_chains is None:
                # default to all residues being cyclized
                self.cyclic_reses = ~self.mask_str.squeeze()
            else:
                # use cyc_chains arg to determine cyclic_reses mask
                # isinstance is the idiomatic type check (was `type(...) is str`).
                assert isinstance(
                    self._conf.inference.cyc_chains, str
                ), "cyc_chains arg must be string"
                cyc_chains = self._conf.inference.cyc_chains
                cyc_chains = [i.upper() for i in cyc_chains]
                hal_idx = (
                    self.contig_map.hal
                )  # the pdb indices of output, knowledge of different chains
                is_cyclized = (
                    ms.mint.zeros_like(self.mask_str).bool().squeeze()
                )  # initially empty

                for ch in cyc_chains:
                    ch_mask = ms.tensor([idx[0] == ch for idx in hal_idx]).bool()
                    # set this whole chain to be cyclic
                    is_cyclized[ch_mask] = True
                self.cyclic_reses = is_cyclized
        else:
            self.cyclic_reses = ms.mint.zeros_like(self.mask_str).bool().squeeze()

        ####################
        ### Get hotspots ###
        ####################
        self.hotspot_0idx = iu.get_idx0_hotspots(
            self.mappings, self.ppi_conf, self.binderlen
        )

        #########################
        ### Set up potentials ###
        #########################

        self.potential_manager = PotentialManager(
            self.potential_conf,
            self.ppi_conf,
            self.diffuser_conf,
            self.inf_conf,
            self.hotspot_0idx,
            self.binderlen,
        )

        # Chain letters for output writing: binder is A, target is B.
        self.chain_idx = ["A" if i < self.binderlen else "B" for i in range(self.L)]

        ########################
        ### Handle Partial T ###
        ########################

        if self.diffuser_conf.partial_T:
            assert self.diffuser_conf.partial_T <= self.diffuser_conf.T
            self.t_step_input = int(self.diffuser_conf.partial_T)
        else:
            self.t_step_input = int(self.diffuser_conf.T)
        t_list = np.arange(1, self.t_step_input + 1)
        seq_T = ms.mint.nn.functional.one_hot(seq_T, num_classes=22).float()

        # Forward-diffuse the pose up to t_step_input, keeping the motif fixed.
        fa_stack, xyz_true = self.diffuser.diffuse_pose(
            xT,
            ms.mint.clone(seq_T),
            atom_mask.squeeze(),
            diffusion_mask=self.diffusion_mask.squeeze(),
            t_list=t_list,
            include_motif_sidechains=self.preprocess_conf.motif_sidechain_input,
        )

        #######################
        ### Set up Denoiser ###
        #######################

        self.denoiser = self.construct_denoiser(self.L, visible=self.mask_seq.squeeze())

        # Start from the most-diffused frame, heavy atoms only.
        xT = ms.mint.clone(fa_stack[-1].squeeze()[:, :14, :])
        return xT, seq_T

    def _preprocess(self, seq, xyz_t, t):
        """
        Extend the parent featurisation with secondary-structure (t1d -> 28)
        and block-adjacency (t2d -> 47) conditioning channels, and offset the
        target chain's residue indices to signal a chain break.
        """
        msa_masked, msa_full, seq, xyz_prev, idx_pdb, t1d, t2d, xyz_t, alpha_t = (
            super()._preprocess(seq, xyz_t, t, repack=False)
        )

        ###################################
        ### Add Adj/Secondary Structure ###
        ###################################

        assert (
            self.preprocess_conf.d_t1d == 28
        ), "The checkpoint you're using hasn't been trained with sec-struc/block adjacency features"
        assert (
            self.preprocess_conf.d_t2d == 47
        ), "The checkpoint you're using hasn't been trained with sec-struc/block adjacency features"

        #####################
        ### Handle Target ###
        #####################

        if self.target:
            # Target residues get the mask ss class (3) unless an explicit
            # target ss was provided.
            blank_ss = ms.mint.nn.functional.one_hot(
                ms.mint.full((self.L - self.binderlen,), 3), num_classes=4
            )
            full_ss = ms.mint.cat((self.ss, blank_ss), dim=0)
            if self._conf.scaffoldguided.target_ss is not None:
                full_ss[self.binderlen :] = self.target_ss
        else:
            full_ss = self.ss
        t1d = ms.mint.cat((t1d, full_ss[None, None].to(dtype=ms.float32)), dim=-1)

        ###########
        ### t2d ###
        ###########

        if self.d_t2d == 47:
            if self.target:
                # Mask all pairs by default; fill in the binder block and,
                # if provided, the target's own adjacency block.
                full_adj = ms.mint.zeros((self.L, self.L, 3))
                full_adj[:, :, -1] = 1.0  # set to mask
                full_adj[: self.binderlen, : self.binderlen] = self.adj
                if self._conf.scaffoldguided.target_adj is not None:
                    full_adj[self.binderlen :, self.binderlen :] = self.target_adj
            else:
                full_adj = self.adj
            t2d = ms.mint.cat((t2d, full_adj[None, None].to(dtype=ms.float32)), dim=-1)

        ###########
        ### idx ###
        ###########

        if self.target:
            # Large residue-index jump marks the binder/target chain break.
            idx_pdb[:, self.binderlen :] += 200

        return msa_masked, msa_full, seq, xyz_prev, idx_pdb, t1d, t2d, xyz_t, alpha_t


class AbSampler(Sampler):
    """
    Sampler for antibody (Ab) design: diffuses CDR loops (and optionally the
    dock) of an antibody against a fixed target, using H/L/T chain-labelled
    inputs parsed into an AbPose.
    """

    def ab_design(self):
        return True

    def sample_init(self):
        """
        We should do some autodetection of Ab chain in this function. The chain which we are designing is the Ab chain
        We should also enforce the H,L,T chain labelling
        """

        # 1) Parse pdb to an ab_pose that can be easily manipulated
        ####################################################################
        self.pose = ab_pose.AbPose()

        # Determine which format the input structure has been provided:
        # either a single HLT-labelled pdb, or separate framework + target pdbs.
        if self.inf_conf.input_pdb is not None:
            assert (
                self.ab_conf.target_pdb is None and self.ab_conf.framework_pdb is None
            ), "Both inference.input_pdb and antibody.target + antibody.framework_pdb cannot be active at the same time."

            self.pose.from_HLT(self.inf_conf.input_pdb)

        # target_pdb and framework_pdb must be given together or not at all.
        assert not (
            (self.ab_conf.target_pdb is None) ^ (self.ab_conf.framework_pdb is None)
        ), "Having antibody.target and not antibody.framework_pdb or vice versa is not currently supported."

        if (
            self.ab_conf.target_pdb is not None
            and self.ab_conf.framework_pdb is not None
        ):
            assert (
                self.diffuser_conf.partial_T is None
            ), "Partial diffusion is only supported when using inference.input_pdb"

            assert (
                self.inf_conf.input_pdb is None
            ), "Both inference.input_pdb and antibody.target + antibody.framework_pdb cannot be active at the same time."

            self.pose.framework_from_HLT(self.ab_conf.framework_pdb)
            self.pose.target_from_HLT(self.ab_conf.target_pdb)

        # 2) Adjust the length of the CDR loops in the AbPose
        ####################################################################################################

        # If we are doing partial diffusion, we will skip the loop length adjustment step
        # Since we are just resampling from the starting scaffold
        if self.diffuser_conf.partial_T:
            print("Partial diffusion detected, skipping loop length adjustment step")

        else:
            # We are doing full diffusion, so we need to adjust the loop lengths
            self.pose.adjust_loop_lengths(self.ab_conf.design_loops)

        # 3) Assemble the ab_item for use downstream. Also determine which residues we should design
        ####################################################################################################
        self.ab_item = util.Dotdict()
        self.ab_item.loop_mask = self.pose.parse_design_mask(self.ab_conf.design_loops)

        assert ms.mint.any(
            self.ab_item.loop_mask
        ), "Received input with no diffused region. Exiting"

        self.L = self.pose.length()
        self.diffusion_mask = ms.mint.ones((self.L,)).bool()  # True is not diffused
        if self.ab_conf.T_scheme == "fixed_dock":
            # Only diffuse loops in fixed dock case
            self.diffusion_mask[self.ab_item.loop_mask] = False

        elif self.ab_conf.T_scheme in {
            "single_T",
            "single_T_better_confidence",
            "no_T",
            "single_T_correct_selfcond",
            "noSeq_single_T_correct_selfcond",
            "selfcond_emb",
        }:
            # Diffuse binder region
            self.diffusion_mask[: self.pose.binder_len()] = False

        else:
            raise NotImplementedError()

        # Needed for downstream processing
        self.ab_item.target = True

        # Target residues sit after the binder in the concatenated pose.
        self.ab_item.target_mask = ms.mint.zeros((self.L,)).bool()
        self.ab_item.target_mask[self.pose.binder_len() :] = True

        self.ab_item.interchain_mask = self.pose.get_interchain_mask()

        self.binderlen = self.pose.binder_len()

        # 4) Parse hotspots
        ##########################################
        self.ab_item.hotspots = self.pose.parse_hotspots(self.ppi_conf.hotspot_res)

        self.ab_item.inputs = self.pose.to_diffusion_inputs()

        # Assign chain_idx, this is used by run_inference.py to assign chain letters upon writing to disk
        self.chain_idx = self.pose.get_chain_idx()
        self.loop_map = self.pose.get_loop_map()

        # 5) Setup Potential Manager, there are not yet any Ab potentials
        ########################################################################
        self.potential_manager = PotentialManager(
            self.potential_conf,
            self.ppi_conf,
            self.diffuser_conf,
            self.inf_conf,
            self.ab_item.hotspots,
            self.pose.binder_len(),
        )

        # These are necessary for compatibility with the parent sample_step function.
        # Sequence is masked only at designed loops; structure follows the
        # diffusion mask (which may also cover the dock).
        self.mask_seq = ms.mint.clone(~self.ab_item.loop_mask)
        # self.mask_seq = ms.mint.clone(self.diffusion_mask)
        self.mask_str = ms.mint.clone(self.diffusion_mask)

        # Determine the timesteps to use for diffusion
        if self.diffuser_conf.partial_T:
            assert self.diffuser_conf.partial_T <= self.diffuser_conf.T
            self.t_step_input = int(self.diffuser_conf.partial_T)
        else:
            self.t_step_input = int(self.diffuser_conf.T)

        t_list = np.arange(1, self.t_step_input + 1)

        # 6) Diffuse the contig mapped regions
        #############################################
        fa_stack, _ = self.diffuser.diffuse_pose(
            self.ab_item.inputs.xyz_true,
            self.ab_item.inputs.seq_true,
            self.ab_item.inputs.atom_mask,
            diffusion_mask=self.diffusion_mask.squeeze(),
            t_list=t_list,
            # diffuse_sidechains=self.preprocess_conf.sidechain_input,
            include_motif_sidechains=self.preprocess_conf.motif_sidechain_input,
        )

        # Start from the most-diffused frame, heavy atoms only.
        xT = ms.mint.clone(fa_stack[-1].squeeze()[:, :14])

        # 7) Mask the input sequence of the CDR loops
        ####################################################
        seq_T = ms.mint.nn.functional.one_hot(
            self.ab_item.inputs.seq_true, num_classes=22
        ).float()
        seq_T[~self.mask_seq, :20] = 0
        seq_T[~self.mask_seq, 21] = 1  # Mask token

        self.denoiser = self.construct_denoiser(self.L, visible=self.diffusion_mask)

        return xT, seq_T

    def _preprocess(self, seq, xyz_t, t):
        """
        Featurize one timestep for the antibody model.

        Args:
            seq: one-hot sequence at time t.
            xyz_t: (L,14,3) coordinates at time t.
            t (int): current timestep.
        Returns:
            msa_masked, msa_full, seq, xyz_prev, idx_pdb, t1d, t2d, xyz_t,
            alpha_t — each with a leading batch dimension where applicable.
        """

        L = xyz_t.shape[0]

        # 1) Generate the time-dependent features
        ################################################

        # Featurize expects xyz to have 27 atom dimensions; pad with NaN.
        tmp_xyz = ms.mint.full((L, 27, 3), np.nan)
        tmp_xyz[:, :14] = xyz_t  # [L,27,3]

        features = featurize(
            self.ab_item,
            seq,
            tmp_xyz,
            self.preprocess_conf.d_t1d,
            self.preprocess_conf.hotspot_dim,
            self.ab_conf.T_scheme,
            1 - (t / self.T),  # fraction of the trajectory completed
            # Logical negation via `not` (was `~`): on Python bools from the
            # config, `~True == -2` and `~False == -1`, both truthy, so the
            # original expressions were always true regardless of the flag.
            not self.preprocess_conf.motif_sidechain_input,
            not self.ab_conf.no_bugfix_t1d_mask,
        )

        # 2) Now generate the time-invariant features
        ################################################

        ## idx_pdb ##
        #############

        idx_pdb = ms.mint.arange(L)  # (L)
        if self.ab_item.target:
            # Do idx jump at chainbreak
            idx_pdb[self.ab_item.target_mask] += 200

        ## Add hotspots to t1d ##
        #########################
        # Channel 22 of t1d carries the per-residue hotspot flag.
        features["t1d"][..., 22] = self.ab_item.hotspots[None, None]

        retval = (
            features["msa_masked"],
            features["msa_full"],
            features["seq"],
            features["xyz_prev"].unsqueeze(0),
            idx_pdb.unsqueeze(0),
            features["t1d"].unsqueeze(0),
            features["t2d"].unsqueeze(0),
            features["xyz_t"].unsqueeze(0),
            features["alpha_t"].unsqueeze(0),
        )

        return retval

    def sample_step(self, *, t, x_t, seq_init, final_step):
        """
        Run one reverse-diffusion step for antibody design, with antibody-
        specific sequence (msa_prev) and structure self-conditioning.

        Returns:
            px0: (L,14,3) model prediction of x0.
            x_t_1: (L,14,3) pose for the next step.
            seq_t_1: clone of seq_init.
            plddt: predicted lDDT of x0.
        """

        msa_masked, msa_full, seq_in, xt_in, idx_pdb, t1d, t2d, xyz_t, alpha_t = (
            self._preprocess(seq_init, x_t, t)
        )

        B, N, L = xyz_t.shape[:3]

        ##################################
        ######## Seq Self Cond ###########
        ##################################
        # Re-feed last step's msa embedding, except on the first step of the
        # trajectory where none exists yet.
        if (
            (t < self.diffuser.T)
            and (t != self.diffuser_conf.partial_T)
            and self.preprocess_conf.selfcondition_msaprev
        ):

            self._log.info("Providing Ab Seq Self Cond")
            msa_prev = self.msa_prev

        else:
            msa_prev = None

        ##################################
        ######## Str Self Cond ###########
        ##################################
        if (t < self.diffuser.T) and (t != self.diffuser_conf.partial_T):

            self._log.info("Providing Ab Str Self Cond")
            xyz_t, t2d, xyz_sc, sc2d = process_selfcond(
                self.prev_pred, t2d, xyz_t, self.ab_conf
            )

            # Correct selfcond now will detect from the checkpoint file whether it
            # should be run or not
            t2d = correct_selfcond(
                t2d,
                self.ab_conf,
                self.ab_item.inputs.xyz_true,
                self.ab_item.loop_mask,
                self.ab_item.target_mask,
                self.ab_item.interchain_mask,
            )

        else:
            # The non-selfcond step for antibodies is to just leave the input as-is
            sc2d, xyz_sc = process_init_selfcond(t2d, xyz_t, self.ab_conf)

        # Routed through the logger (was a bare print) to match the
        # surrounding self._log usage.
        self._log.info("Monitoring target centering")
        self._log.info(xyz_t[0, 0, self.diffusion_mask, 1].mean(dim=0))
        self._log.info(xt_in[0, self.diffusion_mask, 1].mean(dim=0))

        px0 = xt_in
        # for rec in range(self.recycle_schedule[t-1]):
        msa_prev, pair_prev, px0, state_prev, alpha, logits, plddt = self.model(
            msa_masked,
            msa_full,
            seq_in,
            xt_in,
            idx_pdb,
            t1d=t1d,
            t2d=t2d,
            xyz_t=xyz_t,
            alpha_t=alpha_t,
            msa_prev=msa_prev,
            pair_prev=None,
            state_prev=None,
            t=ms.tensor(t),
            return_infer=True,
            motif_mask=self.diffusion_mask.squeeze(),
        )

        # To permit 'recycling' within a timestep, in a manner akin to how this model was trained
        # Aim is to basically just replace the xyz_t with the model's last px0, and to *not* recycle the state, pair or msa embeddings
        # if rec < self.recycle_schedule[t-1] -1:
        #     zeros = ms.mint.zeros((B,1,L,24,3)).float()
        #     xyz_t = ms.mint.cat((px0.unsqueeze(1),zeros), dim=-2) # [B,T,L,27,3]
        #     t2d   = xyz_to_t2d(xyz_t) # [B,T,L,L,44]
        #     px0=xt_in

        # Cache for the next step's structure and sequence self-conditioning.
        self.prev_pred = ms.mint.clone(px0)
        self.msa_prev = ms.mint.clone(msa_prev)

        # prediction of X0
        _, px0 = self.allatom(ms.mint.argmax(seq_in, dim=-1), px0, alpha)
        px0 = px0.squeeze()[:, :14]

        seq_t_1 = ms.mint.clone(seq_init)
        if t > final_step:
            # Partially denoise towards px0 to obtain x_{t-1}.
            x_t_1, px0 = self.denoiser.get_next_pose(
                xt=x_t,
                px0=px0,
                t=t,
                diffusion_mask=self.mask_str.squeeze(),
                align_motif=self.inf_conf.align_motif,
                include_motif_sidechains=self.preprocess_conf.motif_sidechain_input,
            )
            self._log.info(
                f"Timestep {t}, input to next step: { seq2chars(ms.mint.argmax(seq_t_1, dim=-1).tolist())}"
            )
        else:
            # Final step: return the model's prediction directly.
            x_t_1 = ms.mint.clone(px0)

        return px0, x_t_1, seq_t_1, plddt
